{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "024a10c2-6511-4bbf-a789-a12952d57988",
   "metadata": {},
   "source": [
    "# Lab 1, Sampling"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "700e687c",
   "metadata": {
    "height": 233,
    "tags": []
   },
   "outputs": [],
   "source": [
    "from typing import Dict, Tuple\n",
    "from tqdm import tqdm\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "from torch.utils.data import DataLoader\n",
    "from torchvision import models, transforms\n",
    "from torchvision.utils import save_image, make_grid\n",
    "import matplotlib.pyplot as plt\n",
    "from matplotlib.animation import FuncAnimation, PillowWriter\n",
    "import numpy as np\n",
    "from IPython.display import HTML\n",
    "from diffusion_utilities import *"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7c0d229a",
   "metadata": {},
   "source": [
    "# Setting Things Up"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "23507e17",
   "metadata": {
    "height": 1355,
    "tags": []
   },
   "outputs": [],
   "source": [
    "class ContextUnet(nn.Module):\n",
    "    def __init__(self, in_channels, n_feat=256, n_cfeat=10, height=28):  # cfeat - context features\n",
    "        super(ContextUnet, self).__init__()\n",
    "\n",
    "        # number of input channels, number of intermediate feature maps and number of classes\n",
    "        self.in_channels = in_channels\n",
    "        self.n_feat = n_feat\n",
    "        self.n_cfeat = n_cfeat\n",
    "        self.h = height  #assume h == w. must be divisible by 4, so 28,24,20,16...\n",
    "\n",
    "        # Initialize the initial convolutional layer\n",
    "        self.init_conv = ResidualConvBlock(in_channels, n_feat, is_res=True)\n",
    "\n",
    "        # Initialize the down-sampling path of the U-Net with two levels\n",
    "        self.down1 = UnetDown(n_feat, n_feat)        # down1 #[10, 256, 8, 8]\n",
    "        self.down2 = UnetDown(n_feat, 2 * n_feat)    # down2 #[10, 256, 4,  4]\n",
    "        \n",
    "         # original: self.to_vec = nn.Sequential(nn.AvgPool2d(7), nn.GELU())\n",
    "        self.to_vec = nn.Sequential(nn.AvgPool2d((4)), nn.GELU())\n",
    "\n",
    "        # Embed the timestep and context labels with a one-layer fully connected neural network\n",
    "        self.timeembed1 = EmbedFC(1, 2*n_feat)\n",
    "        self.timeembed2 = EmbedFC(1, 1*n_feat)\n",
    "        self.contextembed1 = EmbedFC(n_cfeat, 2*n_feat)\n",
    "        self.contextembed2 = EmbedFC(n_cfeat, 1*n_feat)\n",
    "\n",
    "        # Initialize the up-sampling path of the U-Net with three levels\n",
    "        self.up0 = nn.Sequential(\n",
    "            nn.ConvTranspose2d(2 * n_feat, 2 * n_feat, self.h//4, self.h//4), # up-sample  \n",
    "            nn.GroupNorm(8, 2 * n_feat), # normalize                       \n",
    "            nn.ReLU(),\n",
    "        )\n",
    "        self.up1 = UnetUp(4 * n_feat, n_feat)\n",
    "        self.up2 = UnetUp(2 * n_feat, n_feat)\n",
    "\n",
    "        # Initialize the final convolutional layers to map to the same number of channels as the input image\n",
    "        self.out = nn.Sequential(\n",
    "            nn.Conv2d(2 * n_feat, n_feat, 3, 1, 1), # reduce number of feature maps   #in_channels, out_channels, kernel_size, stride=1, padding=0\n",
    "            nn.GroupNorm(8, n_feat), # normalize\n",
    "            nn.ReLU(),\n",
    "            nn.Conv2d(n_feat, self.in_channels, 3, 1, 1), # map to same number of channels as input\n",
    "        )\n",
    "\n",
    "    def forward(self, x, t, c=None):\n",
    "        \"\"\"\n",
    "        x : (batch, n_feat, h, w) : input image\n",
    "        t : (batch, n_cfeat)      : time step\n",
    "        c : (batch, n_classes)    : context label\n",
    "        \"\"\"\n",
    "        # x is the input image, c is the context label, t is the timestep, context_mask says which samples to block the context on\n",
    "\n",
    "        # pass the input image through the initial convolutional layer\n",
    "        x = self.init_conv(x)\n",
    "        # pass the result through the down-sampling path\n",
    "        down1 = self.down1(x)       #[10, 256, 8, 8]\n",
    "        down2 = self.down2(down1)   #[10, 256, 4, 4]\n",
    "        \n",
    "        # convert the feature maps to a vector and apply an activation\n",
    "        hiddenvec = self.to_vec(down2)\n",
    "        \n",
    "        # mask out context if context_mask == 1\n",
    "        if c is None:\n",
    "            c = torch.zeros(x.shape[0], self.n_cfeat).to(x)\n",
    "            \n",
    "        # embed context and timestep\n",
    "        cemb1 = self.contextembed1(c).view(-1, self.n_feat * 2, 1, 1)     # (batch, 2*n_feat, 1,1)\n",
    "        temb1 = self.timeembed1(t).view(-1, self.n_feat * 2, 1, 1)\n",
    "        cemb2 = self.contextembed2(c).view(-1, self.n_feat, 1, 1)\n",
    "        temb2 = self.timeembed2(t).view(-1, self.n_feat, 1, 1)\n",
    "        #print(f\"uunet forward: cemb1 {cemb1.shape}. temb1 {temb1.shape}, cemb2 {cemb2.shape}. temb2 {temb2.shape}\")\n",
    "\n",
    "\n",
    "        up1 = self.up0(hiddenvec)\n",
    "        up2 = self.up1(cemb1*up1 + temb1, down2)  # add and multiply embeddings\n",
    "        up3 = self.up2(cemb2*up2 + temb2, down1)\n",
    "        out = self.out(torch.cat((up3, x), 1))\n",
    "        return out\n"
   ]
  },
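  {
   "cell_type": "markdown",
   "id": "a1b2c3d4-e5f6-4a7b-8c9d-0e1f2a3b4c5d",
   "metadata": {},
   "source": [
    "As a quick sanity check of the architecture, a dummy forward pass should return a tensor with the same shape as the input image. This is a minimal sketch, not part of the original lab; it assumes the `ResidualConvBlock`, `UnetDown`, `UnetUp`, and `EmbedFC` blocks from `diffusion_utilities` follow the shapes annotated above."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b2c3d4e5-f6a7-4b8c-9d0e-1f2a3b4c5d6e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# sanity check (sketch): a dummy forward pass should preserve the input shape\n",
    "_check_model = ContextUnet(in_channels=3, n_feat=64, n_cfeat=5, height=16)\n",
    "_x = torch.randn(2, 3, 16, 16)       # dummy batch of 16x16 RGB images\n",
    "_t = torch.rand(2, 1, 1, 1)          # dummy normalized timesteps\n",
    "print(_check_model(_x, _t).shape)    # expected: torch.Size([2, 3, 16, 16])"
   ]
  },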
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "54c3a942",
   "metadata": {
    "height": 250,
    "tags": []
   },
   "outputs": [],
   "source": [
    "# hyperparameters\n",
    "\n",
    "# diffusion hyperparameters\n",
    "timesteps = 500\n",
    "beta1 = 1e-4\n",
    "beta2 = 0.02\n",
    "\n",
    "# network hyperparameters\n",
    "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else torch.device('cpu'))\n",
    "n_feat = 64 # 64 hidden dimension feature\n",
    "n_cfeat = 5 # context vector is of size 5\n",
    "height = 16 # 16x16 image\n",
    "save_dir = './weights/'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a705d0a8",
   "metadata": {
    "height": 114,
    "tags": []
   },
   "outputs": [],
   "source": [
    "# construct DDPM noise schedule\n",
    "b_t = (beta2 - beta1) * torch.linspace(0, 1, timesteps + 1, device=device) + beta1\n",
    "a_t = 1 - b_t\n",
    "ab_t = torch.cumsum(a_t.log(), dim=0).exp()    \n",
    "ab_t[0] = 1"
   ]
  },
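  {
   "cell_type": "markdown",
   "id": "c3d4e5f6-a7b8-4c9d-0e1f-2a3b4c5d6e7f",
   "metadata": {},
   "source": [
    "The schedule sets $\\beta_t$ linearly between `beta1` and `beta2`, with $\\alpha_t = 1 - \\beta_t$ and $\\bar{\\alpha}_t = \\prod_{s=1}^{t} \\alpha_s$. The code computes the cumulative product in log space (`cumsum` of logs, then `exp`) and pins `ab_t[0]` to 1 so that $x_0$ carries no noise. A quick check (a sketch, not part of the original lab) confirms the log-space form matches a direct `cumprod` for $t \\geq 1$:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d4e5f6a7-b8c9-4d0e-1f2a-3b4c5d6e7f8a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# sanity check (sketch): log-space cumulative sum vs. direct cumulative product\n",
    "# index 0 is skipped because ab_t[0] is pinned to 1 by hand above\n",
    "assert torch.allclose(ab_t[1:], torch.cumprod(a_t, dim=0)[1:], atol=1e-5)"
   ]
  },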
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6bc9001e",
   "metadata": {
    "height": 63,
    "tags": []
   },
   "outputs": [],
   "source": [
    "# construct model\n",
    "nn_model = ContextUnet(in_channels=3, n_feat=n_feat, n_cfeat=n_cfeat, height=height).to(device)"
   ]
  },
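  {
   "cell_type": "markdown",
   "id": "e5f6a7b8-c9d0-4e1f-8a3b-4c5d6e7f8a9b",
   "metadata": {},
   "source": [
    "For reference, a quick parameter count (a sketch, not part of the original lab):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f6a7b8c9-d0e1-4f2a-8b4c-5d6e7f8a9b0c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# report the model size\n",
    "n_params = sum(p.numel() for p in nn_model.parameters())\n",
    "print(f\"ContextUnet parameters: {n_params:,}\")"
   ]
  },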
  {
   "cell_type": "markdown",
   "id": "f265f9c6",
   "metadata": {},
   "source": [
    "# Sampling"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9fa7aa8a",
   "metadata": {
    "height": 148,
    "tags": []
   },
   "outputs": [],
   "source": [
    "# helper function; removes the predicted noise (but adds some noise back in to avoid collapse)\n",
    "def denoise_add_noise(x, t, pred_noise, z=None):\n",
    "    if z is None:\n",
    "        z = torch.randn_like(x)\n",
    "    noise = b_t.sqrt()[t] * z\n",
    "    mean = (x - pred_noise * ((1 - a_t[t]) / (1 - ab_t[t]).sqrt())) / a_t[t].sqrt()\n",
    "    return mean + noise"
   ]
  },
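  {
   "cell_type": "markdown",
   "id": "a7b8c9d0-e1f2-4a3b-8c5d-6e7f8a9b0c1d",
   "metadata": {},
   "source": [
    "For reference, `denoise_add_noise` implements one reverse step of DDPM sampling (Ho et al., 2020):\n",
    "\n",
    "$$x_{t-1} = \\frac{1}{\\sqrt{\\alpha_t}}\\left(x_t - \\frac{1-\\alpha_t}{\\sqrt{1-\\bar{\\alpha}_t}}\\,\\epsilon_\\theta(x_t, t)\\right) + \\sqrt{\\beta_t}\\,z, \\qquad z \\sim \\mathcal{N}(0, I)$$\n",
    "\n",
    "where $\\epsilon_\\theta$ is the network's noise prediction and $\\sqrt{\\beta_t}\\,z$ is the 'extra noise' injected at every step except the last."
   ]
  },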
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a0c30c67",
   "metadata": {
    "height": 97,
    "tags": []
   },
   "outputs": [],
   "source": [
    "# load in model weights and set to eval mode\n",
    "nn_model.load_state_dict(torch.load(f\"{save_dir}/model_trained.pth\", map_location=device))\n",
    "nn_model.eval()\n",
    "print(\"Loaded in Model\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6d31547d",
   "metadata": {
    "height": 437,
    "tags": []
   },
   "outputs": [],
   "source": [
    "# sample using standard algorithm\n",
    "@torch.no_grad()\n",
    "def sample_ddpm(n_sample, save_rate=20):\n",
    "    # x_T ~ N(0, 1), sample initial noise\n",
    "    samples = torch.randn(n_sample, 3, height, height).to(device)  \n",
    "\n",
    "    # array to keep track of generated steps for plotting\n",
    "    intermediate = [] \n",
    "    for i in range(timesteps, 0, -1):\n",
    "        print(f'sampling timestep {i:3d}', end='\\r')\n",
    "\n",
    "        # reshape time tensor\n",
    "        t = torch.tensor([i / timesteps])[:, None, None, None].to(device)\n",
    "\n",
    "        # sample some random noise to inject back in. For i = 1, don't add back in noise\n",
    "        z = torch.randn_like(samples) if i > 1 else 0\n",
    "\n",
    "        eps = nn_model(samples, t)    # predict noise e_(x_t,t)\n",
    "        samples = denoise_add_noise(samples, i, eps, z)\n",
    "        if i % save_rate ==0 or i==timesteps or i<8:\n",
    "            intermediate.append(samples.detach().cpu().numpy())\n",
    "\n",
    "    intermediate = np.stack(intermediate)\n",
    "    return samples, intermediate"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d4f7e888-2abf-4394-a86d-100805e92fff",
   "metadata": {
    "height": 114,
    "tags": []
   },
   "outputs": [],
   "source": [
    "# visualize samples\n",
    "plt.clf()\n",
    "samples, intermediate_ddpm = sample_ddpm(32)\n",
    "animation_ddpm = plot_sample(intermediate_ddpm,32,4,save_dir, \"ani_run\", None, save=False)\n",
    "HTML(animation_ddpm.to_jshtml())"
   ]
  },
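  {
   "cell_type": "markdown",
   "id": "b8c9d0e1-f2a3-4b4c-8d6e-7f8a9b0c1d2e",
   "metadata": {},
   "source": [
    "The final batch can also be written to disk as an image grid. This is a minimal sketch, not part of the original lab; it assumes the samples land roughly in $[-1, 1]$, so they are clamped and rescaled to $[0, 1]$ before saving."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c9d0e1f2-a3b4-4c5d-8e7f-8a9b0c1d2e3f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# save the final samples as a grid (sketch; assumes samples roughly in [-1, 1])\n",
    "grid = make_grid((samples.clamp(-1, 1) + 1) / 2, nrow=8)  # map to [0, 1]\n",
    "save_image(grid, f\"{save_dir}/ddpm_samples.png\")"
   ]
  },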
  {
   "cell_type": "markdown",
   "id": "0f0ea7da-1a9c-4213-b508-38f57cd6dc9d",
   "metadata": {},
   "source": [
    "#### Demonstrate incorrectly sample without adding the 'extra noise'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "22c13cb6",
   "metadata": {
    "height": 437,
    "tags": []
   },
   "outputs": [],
   "source": [
    "# incorrectly sample without adding in noise\n",
    "@torch.no_grad()\n",
    "def sample_ddpm_incorrect(n_sample):\n",
    "    # x_T ~ N(0, 1), sample initial noise\n",
    "    samples = torch.randn(n_sample, 3, height, height).to(device)  \n",
    "\n",
    "    # array to keep track of generated steps for plotting\n",
    "    intermediate = [] \n",
    "    for i in range(timesteps, 0, -1):\n",
    "        print(f'sampling timestep {i:3d}', end='\\r')\n",
    "\n",
    "        # reshape time tensor\n",
    "        t = torch.tensor([i / timesteps])[:, None, None, None].to(device)\n",
    "\n",
    "        # don't add back in noise\n",
    "        z = 0\n",
    "\n",
    "        eps = nn_model(samples, t)    # predict noise e_(x_t,t)\n",
    "        samples = denoise_add_noise(samples, i, eps, z)\n",
    "        if i%20==0 or i==timesteps or i<8:\n",
    "            intermediate.append(samples.detach().cpu().numpy())\n",
    "\n",
    "    intermediate = np.stack(intermediate)\n",
    "    return samples, intermediate"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "84d4ee62",
   "metadata": {
    "height": 114,
    "tags": []
   },
   "outputs": [],
   "source": [
    "# visualize samples\n",
    "plt.clf()\n",
    "samples, intermediate = sample_ddpm_incorrect(32)\n",
    "animation = plot_sample(intermediate,32,4,save_dir, \"ani_run\", None, save=False)\n",
    "HTML(animation.to_jshtml())"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ace3580b-d35c-4b7f-9b6f-dab2e0a850c9",
   "metadata": {},
   "source": [
    "# Acknowledgments\n",
    "Sprites by ElvGames, [FrootsnVeggies](https://zrghr.itch.io/froots-and-veggies-culinary-pixels) and  [kyrise](https://kyrise.itch.io/)   \n",
    "This code is modified from, https://github.com/cloneofsimo/minDiffusion   \n",
    "Diffusion model is based on [Denoising Diffusion Probabilistic Models](https://arxiv.org/abs/2006.11239) and [Denoising Diffusion Implicit Models](https://arxiv.org/abs/2010.02502)\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
